import pandas as pd
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from scipy import stats

from Voicelab.pipeline.Node import Node
from parselmouth.praat import call
from Voicelab.toolkits.Voicelab.VoicelabNode import VoicelabNode

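###################################################################################################
# MEASURE JITTER NODE
# WARIO pipeline node for measuring the jitter of a voice.
###################################################################################################
# ARGUMENTS
# 'voice'         : sound object generated by parselmouth (Praat) to be analysed
# 'Pitch Floor'   : lower pitch bound for the periodic point-process analysis
# 'Pitch Ceiling' : upper pitch bound for the periodic point-process analysis
###################################################################################################
# RETURNS
# A dict with the keys "Local Jitter", "Local Absolute Jitter", "RAP Jitter", "ppq5 Jitter" and
# "ddp Jitter"; each value is the measured jitter, or "Jitter Measurement Failed" on error.
###################################################################################################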


class MeasureJitterNode(VoicelabNode):
    """WARIO pipeline node that measures the jitter of a voice.

    For each voice, a periodic point process is built from ``self.args["voice"]``
    and five Praat jitter measures are queried from it. The per-voice values are
    also accumulated in ``self.state`` so that ``end`` can reduce them to a
    single principal component across all processed voices.
    """

    # (Praat query command, state list key, result key) for every jitter measure.
    # This order is the column order used by jitter_pca and the key order of the
    # dict returned by process.
    _JITTER_MEASURES = (
        ("Get jitter (local)", "local_jitter_list", "Local Jitter"),
        (
            "Get jitter (local, absolute)",
            "localabsolute_jitter_list",
            "Local Absolute Jitter",
        ),
        ("Get jitter (rap)", "rap_jitter_list", "RAP Jitter"),
        ("Get jitter (ppq5)", "ppq5_jitter_list", "ppq5 Jitter"),
        ("Get jitter (ddp)", "ddp_jitter_list", "ddp Jitter"),
    )

    def __init__(self, *args, **kwargs):
        """Set default arguments and empty per-voice accumulators.

        Args:
            *args: forwarded unchanged to ``VoicelabNode.__init__``.
            **kwargs: forwarded unchanged to ``VoicelabNode.__init__``.
        """
        super().__init__(*args, **kwargs)

        # Defaults for the Praat "Get jitter (...)" queries. "voice", "Pitch Floor"
        # and "Pitch Ceiling" are read in process() but not defaulted here;
        # presumably the pipeline supplies them — confirm against callers.
        self.args = {
            "start_time": 0,
            "end_time": 0,
            "shortest_period": 0.0001,
            "longest_period": 0.02,
            "maximum_period_factor": 1.3,
            "Measure PCA": True,
        }
        # One entry is appended per processed voice (NaN when the measurement
        # failed), so row i of the PCA input lines up with the i-th processed voice.
        self.state = {
            "local_jitter_list": [],
            "localabsolute_jitter_list": [],
            "rap_jitter_list": [],
            "ppq5_jitter_list": [],
            "ddp_jitter_list": [],
        }

    def end(self, results):
        """Attach a jitter PCA score to each voice's results.

        WARIO hook; presumably called once after every voice has been processed.

        Args:
            results: sequence of per-voice result mappings, each indexed by this
                node. Assumed to be in the same order in which ``process`` ran —
                TODO confirm with the pipeline.

        Returns:
            ``results``, updated in place. When ``self.args["Measure PCA"]`` is
            true, each ``results[i][self]`` gains a "PCA Result" entry holding
            the i-th principal-component score as a float, or "PCA Failed" when
            no score is available for that voice.
        """
        if not self.args["Measure PCA"]:
            return results

        pca_results = self.jitter_pca()
        if pca_results is None:
            return results

        if isinstance(pca_results, str):
            # jitter_pca reports a failed fit with a message string; indexing it
            # would hand out single characters, so mark every voice as failed.
            scores = []
        else:
            scores = np.ravel(pca_results)

        # NOTE(review): if the same node instance is reused for several pipeline
        # runs, state keeps growing and rows may no longer match `results`.
        for i, result in enumerate(results):
            score = float(scores[i]) if i < len(scores) else np.nan
            result[self]["PCA Result"] = "PCA Failed" if np.isnan(score) else score

        return results

    def process(self):
        """Measure jitter for the current voice.

        WARIO hook called once for each voice file. Each measured value (or NaN
        on failure) is also appended to the matching ``self.state`` list for the
        cross-voice PCA in ``end``.

        Returns:
            dict mapping "Local Jitter", "Local Absolute Jitter", "RAP Jitter",
            "ppq5 Jitter" and "ddp Jitter" to the values Praat returns, or each
            of them to "Jitter Measurement Failed" if any step raises.
        """
        voice = self.args["voice"]
        try:
            # Pitch bounds for the periodic point-process analysis.
            pitch_floor = self.args["Pitch Floor"]
            pitch_ceiling = self.args["Pitch Ceiling"]

            # Shared time-range / period arguments for every jitter query.
            # TODO: accept the defaults unless the user is in advanced mode.
            query_args = (
                self.args["start_time"],
                self.args["end_time"],
                self.args["shortest_period"],
                self.args["longest_period"],
                self.args["maximum_period_factor"],
            )

            point_process = call(
                voice, "To PointProcess (periodic, cc)", pitch_floor, pitch_ceiling
            )
            measured = [
                call(point_process, command, *query_args)
                for command, _, _ in self._JITTER_MEASURES
            ]
        except Exception:
            # Best effort: record NaN so the state lists stay aligned with the
            # processed voices, and report the failure instead of raising.
            for _, state_key, _ in self._JITTER_MEASURES:
                self.state[state_key].append(np.nan)
            return {
                result_key: "Jitter Measurement Failed"
                for _, _, result_key in self._JITTER_MEASURES
            }

        output = {}
        for (_, state_key, result_key), value in zip(self._JITTER_MEASURES, measured):
            # Append (not overwrite) so end() can run PCA across all voices.
            self.state[state_key].append(value)
            output[result_key] = value
        return output

    def jitter_pca(self):
        """Reduce the accumulated jitter measures to one principal component.

        Each row holds the five measures of one processed voice, with columns in
        ``_JITTER_MEASURES`` order. The columns are z-scored and projected onto
        their first principal component. Rows containing a non-finite value
        (e.g. a failed measurement recorded as NaN) are left out of the fit.

        Returns:
            A float array of shape ``(n_voices, 1)`` with each voice's score
            (NaN for rows left out of the fit), or the string "PCA Failed" if
            the data cannot be assembled or the fit raises (e.g. no usable rows).
        """
        try:
            x = np.column_stack(
                [
                    np.asarray(self.state[state_key], dtype=float)
                    for _, state_key, _ in self._JITTER_MEASURES
                ]
            )
            usable = np.isfinite(x).all(axis=1)
            principal_components = np.full((x.shape[0], 1), np.nan)

            # z-score the jitter measurements, then run the PCA on usable rows.
            z_scores = StandardScaler().fit_transform(x[usable])
            principal_components[usable] = PCA(n_components=1).fit_transform(z_scores)
        except Exception:
            principal_components = "PCA Failed"

        return principal_components
